QTM 447 Lecture 27: Interpretable Machine Learning

Kevin McAlister

April 22, 2025

\[ \newcommand\hbb{{\hat{\boldsymbol \beta}}} \newcommand\bb{{\boldsymbol \beta}} \newcommand\expn{{\frac{1}{N} \sum \limits_{i = 1}^N}} \newcommand\sumk{\sum \limits_{k = 1}^K} \newcommand\argminb{\underset{\bb}{\text{argmin }}} \newcommand\argmaxb{\underset{\bb}{\text{argmax }}} \newcommand\gtheta{\mathbf g(\boldsymbol \theta)} \newcommand\htheta{\mathbf H(\boldsymbol \theta)} \]

Clever Hans

Interpretability

Model performance \(\neq\) success

  • A classifier can perform well for the wrong reasons!

Interpretability

Simplicity reigns supreme?

  • COMPAS recidivism algorithm

When models get too simple, we lose the ability to make meaningful predictions!

  • A tradeoff of complexity and interpretability

How can we make deep learning interpretable?

Interpretability

Who cares?

We all should.

Safety: Detect and prevent critical failures

  • In the face of a failure, learn how to prevent it from occurring again

Example: Autonomous driving misclassifications

Interpretability

Fairness: Reveal and mitigate biases to ensure equitable treatment across groups

  • Why does a prediction get made?

  • How can we determine how to fix problems if our data is inherently biased?

Example: Google Photo Tagging Problems

Interpretability

Privacy/Security: Ensuring that sensitive information in the data is protected

  • Given the richness of some models, is it possible to use them in an adversarial way to make them do bad things?

  • Is it possible to back-engineer private information?

Example: Tay the Chatbot

Interpretability

Legal: We have to

The GDPR’s “right to explanation”

  • People have the right to be given an explanation for an output of an algorithm that has been used to determine an outcome

  • Especially true for decisions that impact your life (credit scores, social scores, etc.)

Opponents say that this requirement will stifle innovation

  • Can you guess who those opponents might be?

Interpretability

Problem: Neural Networks are complicated

\[ \mathbf X \rightarrow g(\cdot) \rightarrow \hat{y} \]

The standard flat NN:

\[ \hat{y} \propto \mathbf W \varphi(\mathbf W \varphi (...)) \]

How can we interpret anything about how a decision was made?

  • A NN is the classic black-box algorithm

Interpretability

We’re going to cover a number of approaches today!

  • Nowhere near all of them.

A rapidly evolving area

Interpretability

A pretty good definition of interpretability (Biran and Cotton, 2017):

Interpretability is the degree to which a human can understand the cause of a decision

Another good one (Kim, Khanna, and Koyejo, 2016):

… a user can correctly and efficiently predict the method’s results.

Interpretability

To what extent can we do this with modern ML methods?

Best place to start is with a class of models that are inherently interpretable

  • Glass-box models

The model can provide all of the info necessary for a human to replicate exactly how a prediction was made with pen and paper

  • Or a computational graph.

What are some examples of glass-box models that you can think of?

Linear Regression

Classic example: Linear Regression

\[ \hat{y} = g(\mathbf X \hat{\boldsymbol \beta}) \]

  • As any \(x_j\) increases, we know exactly how \(\hat{y}\) changes

NOTE: We’re talking about causality in terms of the prediction, not in terms of actual outcomes.

  • Changes in a value necessarily result in changes in the prediction

  • Doesn’t mean that a change in the prediction will correspond with a change in the actual outcome

Linear Regression

This is an example of a glass-box model because the structure makes it inherently easy to see how things change as a function of each feature!

  • Very interpretable.

Other examples:

  • Simple decision trees

  • Ridge/LASSO

  • Generalized Additive Models

Linear Regression

What is the generic problem with these predictive models?

A trade-off exists:

  • As transparency increases, bias increases

Any intuition as to why?

Linear Regression

Interpretation requires functional specification

  • Can’t easily say that I know how things will change if I don’t know the function that creates that mapping!

Functional specification = Assumptions

  • Or really, comfort with misspecification

Modern machine learning requires models that can uncover the correct functional form, whatever it is, given enough data!

  • Remember what this class of learning methods is called?

Universal Approximators

Modern ML requires the usage of universal approximators

  • Model complex dependencies in a meaningful way

Problem: These methods are typically local in nature!

  • Random Forests: Local rectangles

  • KNN: Local neighborhoods

  • Gaussian SVMs: Local distance structures

  • NNs: Local linear approximations

No one-size fits all method to interpret.

Universal Approximators

What do we do?

Two main approaches:

  • Model agnostic post-hoc interpretability: apply methods using the trained model to better understand how predictions are made

    • Model specific - examine parts of a model to understand (gradients, LIME)

    • Global approaches - examine the behavior of predictions on average using the prediction machine

  • Intrinsic Interpretability: change how the model is estimated to create windows that show how predictions are made

    • Opaque glass boxes

Universal Approximators

While methods exist for other ML algorithms, we’re going to completely concentrate on neural networks

  • Eventually shift to just images and text

Methods like XGBoost suffer from the same issue

  • Can (and should) use some of these methods to better understand the black box nature of the algorithm!

LIME

One approach to interpretation is to only expect that we can explain predictions locally

  • Defense: the whole point is that the model is so complex that local explanations are really the only thing that matters

  • A complicated combination of credit score, income, etc. leads to mortgage approval

Explain individual predictions made by a black box!

LIME

Generic approach:

\(\hat{y} = g(\mathbf x)\)

I don’t know exactly what \(g()\) is, but I can quickly query it.

I want to know how changes in the feature values for my example result in changes to the prediction

LIME

Intuitive approach:

Let \(\mathbf x'\) be a feature vector close to \(\mathbf x\) in which one value has been changed, with associated prediction \(g(\mathbf x') = \hat{y}'\)

If \(\hat{y}\) and \(\hat{y}'\) are really different, then that feature must have been important in making the prediction!

  • Makes sense

How can we quantify this argument and apply it to higher dimensional settings?

LIME

Let’s get more specific:

We have a target instance \(\mathbf x \in \mathbb R^P\)

  1. Create a set of \(N\) perturbed samples, \(\{\mathbf x_i'\}_{i = 1}^N\)
  2. Pass each perturbed sample through the black box to get \(\{ \hat{y}_i\}_{i = 1}^N\)
  3. Compute a proximity weight for each perturbed sample using a proper kernel function - \(w_i = \exp \left( -\frac{\|\mathbf x - \mathbf x_i'\|^2}{2\sigma^2} \right)\)
  4. Fit a sparse linear surrogate using the LASSO penalty:

\[ \hat{\boldsymbol \beta} = \underset{\{\beta_0 , \boldsymbol \beta \}}{\text{argmin }} \sum \limits_{i = 1}^N w_i(\hat{y}_i - \beta_0 - \mathbf x_i' \boldsymbol \beta)^2 + \lambda \sum \limits_{j = 1}^P | \beta_j | \]
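The four steps above can be sketched in NumPy. This is a minimal illustration, not the reference implementation: the perturbation scale, kernel bandwidth \(\sigma\), penalty \(\lambda\), and the simple coordinate-descent LASSO solver are all illustrative choices.

```python
import numpy as np

def soft(z, t):
    """Soft-thresholding operator for the LASSO coordinate update."""
    return np.sign(z) * max(abs(z) - t, 0.0)

def lime_explain(predict_fn, x, n_samples=500, sigma=1.0, lam=0.1, n_iter=100, seed=0):
    rng = np.random.default_rng(seed)
    P = x.shape[0]
    Xp = x + rng.normal(scale=0.5, size=(n_samples, P))        # 1. perturbed samples
    y = np.array([predict_fn(row) for row in Xp])              # 2. query the black box
    w = np.exp(-np.sum((Xp - x) ** 2, axis=1) / (2 * sigma**2))  # 3. proximity weights
    beta0, beta = 0.0, np.zeros(P)                             # 4. weighted LASSO surrogate
    for _ in range(n_iter):
        beta0 = np.sum(w * (y - Xp @ beta)) / np.sum(w)        # intercept: weighted mean
        for j in range(P):                                     # coordinate descent on beta_j
            r_j = y - beta0 - Xp @ beta + Xp[:, j] * beta[j]   # partial residual
            rho = np.sum(w * Xp[:, j] * r_j)
            z = np.sum(w * Xp[:, j] ** 2)
            beta[j] = soft(rho, lam / 2) / z
    return beta0, beta
```

Querying a (secretly linear) black box, the surrogate recovers the locally important features and shrinks the irrelevant one toward zero.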

LIME

Local Interpretable Model-Agnostic Explanations (LIME) is an approach that creates a local linear approximation to the predictive function in the neighborhood of the point of interest

The coefficients of this weighted LASSO correspond to the loss-minimizing importance function conditional on \(\lambda\)

Any thoughts on the intuition of this?

  • What does a large coefficient correspond to here?

LIME

Ultimately, just trying to visualize the predictive surface in the neighborhood around the point!

  • Small changes in \(x_j\) resulting in big changes in \(\hat{y}\) corresponds to important features

Works in any scenario where we can generate a neighborhood of points around the example.

LIME

A little tricky for images and text!

  • We’ve got better tools for those.

For any model estimated using PyTorch, we can use the captum library in Python to run a LIME model!

LIME

Strengths:

  • Explanations are short and easy to understand

  • Easy to implement

Weaknesses:

  • Very sensitive to neighborhood choice

  • “Noising” method can lead to really unlikely data points due to strong dependencies between features

  • Can only be applied one point at a time

Permutation Importance

What if I want a more global measure of feature importance?

  • How sensitive is this prediction to small perturbations?

  • How much does this feature contribute to explaining the predictions for all data in my data set?

Local sensitivity is important for explaining individual predictions

  • Not as important in the case where we want to think about feature importance more in the regression sense

  • Think LASSO

Permutation Importance

Another way to think about feature importance

Including this feature significantly reduces the overall loss of the predictive function

  • Doesn’t answer the question of how a feature contributes to an individual prediction

  • Does answer the question of how much a feature contributes to predictive performance overall

How could we assess this globally in a model agnostic way?

Permutation Importance

One approach: Leave One Feature Out fitting

  1. Fit the overall model minimizing your loss function. Compute the overall loss on a held out validation set.
  2. For feature \(j \in (1,2,...,P)\):
    1. Leave \(x_j\) out of the model
    2. Refit to minimize the loss function. Compute the loss on the held out validation set.
    3. Subtract the overall loss from the submodel loss
  3. Rank each feature in terms of the overall increase in loss

Problem: Takes way too long!!!!

Permutation Importance

Slightly better approach: permutation importance

  1. Fit the overall model minimizing your loss function. Compute the overall loss on a held out validation set.
  2. For feature \(j \in (1,2,...,P)\):
    1. Randomly permute the values of this feature for your data set
    2. Create predictions using this permutation feature set. Compute the loss on the held out validation set when we permute feature \(j\).
    3. Subtract the overall loss from the submodel loss
  3. Rank each feature in terms of the overall increase in loss.
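A minimal NumPy sketch of the recipe above; averaging over a few random permutations (`n_repeats`) is an optional stabilizer, not part of the slide's algorithm:

```python
import numpy as np

def permutation_importance(predict_fn, X_val, y_val, loss_fn, n_repeats=5, seed=0):
    """Model-agnostic permutation importance on a held-out validation set."""
    rng = np.random.default_rng(seed)
    base_loss = loss_fn(y_val, predict_fn(X_val))        # 1. overall validation loss
    importances = np.zeros(X_val.shape[1])
    for j in range(X_val.shape[1]):                      # 2. loop over features
        losses = []
        for _ in range(n_repeats):
            X_perm = X_val.copy()
            X_perm[:, j] = rng.permutation(X_perm[:, j])           # 2a. shuffle feature j
            losses.append(loss_fn(y_val, predict_fn(X_perm)))      # 2b. loss with j broken
        importances[j] = np.mean(losses) - base_loss     # 2c. increase in loss
    return importances                                   # 3. rank features by these values
```

For a black box that secretly only uses the first feature, permuting that feature inflates the loss while permuting the others changes nothing.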

Permutation Importance

Important global model agnostic method!

  • Can (and should) be applied to any black box model on tabular data!

  • The default method for RFs, XGBoost, and Tabular NNs

Problem: Doesn’t really apply to images and text.

Saliency Maps

LIME and Permutations are good for all models

  • Just hard to set up for images and text, since the inputs aren’t really of the one-number-per-feature form

More complex inputs require more complex methods!

Let’s look at an approach for image classification.

Saliency Maps

What makes this a picture of a dog?

Is there an agnostic way that we can know what pixels correspond most heavily with a certain class label?

For image classification problems, the final layer of the CNN produces scores for each class (i.e. the predictions)

\[ \mathbf s = (s_1,s_2,...,s_C) \]

Saliency Maps

By definition, we construct deep learning models in a way that they are fully backpropable

This means that for any image, \(\mathbf x_i\), we can compute:

\[ \frac{\partial s_j}{\partial \mathbf x_i} \]

  • The gradient of a class score w.r.t. the input image!

Since the gradient of a scalar w.r.t. a complex input has the same shape as that input, the gradient of the score is a pixel-by-pixel set of values describing how much a change in each pixel value (taking into account the convolutional structure) changes the score!

Saliency Maps

The Vanilla Gradient method speeds up the process of creating LIME-style maps for images by using gradients!

Via a first-order Taylor series expansion:

\[ s_c(\mathbf x) \approx \mathbf w^T \mathbf x + b \]

where:

\[ \mathbf w = \frac{\partial s_c}{\partial \mathbf x}\rvert_{\mathbf x_0} \]

  • Just a local linear surrogate model!

Saliency Maps

Vanilla Gradients:

  1. Perform a forward pass of the image of interest (e.g. get the class score)
  2. Compute the gradient of the class score of interest with respect to the input pixels
  3. Visualize the gradient
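In PyTorch you would get this gradient in one backward pass (e.g. via `torch.autograd.grad`, or the captum library). Here is a self-contained NumPy stand-in that uses finite differences on a toy score function just to illustrate the three steps:

```python
import numpy as np

def vanilla_saliency(score_fn, image, eps=1e-5):
    """Finite-difference stand-in for the backprop gradient d s_c / d x.
    In practice one backward pass through the network gives the same thing."""
    base = score_fn(image)                 # 1. forward pass: the class score
    grad = np.zeros_like(image)
    for idx in np.ndindex(image.shape):    # 2. gradient of the score w.r.t. each pixel
        bumped = image.copy()
        bumped[idx] += eps
        grad[idx] = (score_fn(bumped) - base) / eps
    return np.abs(grad)                    # 3. visualize the gradient magnitude
```

Pixels the score depends on light up; pixels it ignores stay at zero.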

Saliency Maps

Problem: Vanilla gradients tend to be really sensitive and noisy

  • Also suffer from saturation problems, since gradients flowing back through ReLU activations evaluate to zero wherever units are inactive

Reason: Derivatives fluctuate greatly at small scales

  • No incentive to smooth gradients for a model that fits the training data well

  • Overfit w.r.t. billions of parameters, but works well when put together

Saliency Maps

Solution: SmoothGrad

  1. Generate multiple versions of the image of interest by adding noise to it
  2. Perform a forward pass for each noisy version (e.g. get each class score)
  3. Compute the gradient of the class score of interest with respect to the input pixels for each version
  4. Average the resulting pixel attribution maps

Adding a little noise to see where gradients are randomly fluctuating will help to smooth the process!
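A hedged NumPy sketch of SmoothGrad, again using a finite-difference stand-in for the backward pass; the number of noisy samples and the noise scale are illustrative hyperparameters:

```python
import numpy as np

def smoothgrad(score_fn, image, n_samples=25, noise_sd=0.1, eps=1e-5, seed=0):
    """Average saliency maps over noisy copies of the image."""
    rng = np.random.default_rng(seed)

    def grad(img):  # stand-in for one autograd backward pass
        base, g = score_fn(img), np.zeros_like(img)
        for idx in np.ndindex(img.shape):
            bumped = img.copy()
            bumped[idx] += eps
            g[idx] = (score_fn(bumped) - base) / eps
        return g

    total = np.zeros_like(image)
    for _ in range(n_samples):             # 1. noisy versions of the image
        noisy = image + rng.normal(scale=noise_sd, size=image.shape)
        total += grad(noisy)               # 2-3. score and gradient per noisy copy
    return np.abs(total / n_samples)       # 4. average the attribution maps
```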

Saliency Maps

Vanilla Gradients work decently, but can be jumpy due to the pixel level feature importance mapping

An alternative approach attempts to use the feature maps present in CNNs to impose spatial semantic continuity in the saliency map!

Saliency Maps

The final layer before the fully connected classification head carries a lower dimensional, denoised representation of the original image!

  • The final feature map theoretically contains information about features within the image that is semantically meaningful.

Alter the approach to only compute the gradient at the final convolutional layer and upscale that back to the original size of the image.

  • How does this work when our final convolutional layer is of size \(4 \times 4 \times 512\)?

Saliency Maps

Gradient Weighted Class Activation Mappings (Grad-CAM) does the following:

  1. Compute the class scores for an input image

  2. Assume that the final convolutional layer yields a tensor of size \(k \times k \times D\). For each \(k \times k\) filter, compute the gradient of the score w.r.t that filter.

  3. Across all \(D\) filters, globally average pool the “pixels” in the gradient. This finds which of the \(D\) filters have a large influence on the class label. This yields a filter weight, \(\alpha_d\)

  4. Create the coarse heatmap

    \[ h_{ij} = \text{ReLU}\left(\sum \limits_{d = 1}^D \alpha_d x_{i,j,d} \right) \]

  5. Bilinearly upscale the heatmap to the original image size
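Assuming we already have the final-layer activations and the gradient of the class score with respect to them, steps 3–5 can be sketched as follows (nearest-neighbor upscaling stands in for the bilinear interpolation):

```python
import numpy as np

def grad_cam(feature_map, score_grads):
    """feature_map, score_grads: (k, k, D) activations and d(score)/d(activations)."""
    alpha = score_grads.mean(axis=(0, 1))              # 3. globally average-pool gradients
    cam = np.einsum("ijd,d->ij", feature_map, alpha)   # weight each filter by alpha_d
    return np.maximum(cam, 0.0)                        # 4. ReLU -> coarse heatmap

def upscale(cam, factor):
    """Nearest-neighbor stand-in for the bilinear upscaling in step 5."""
    return np.repeat(np.repeat(cam, factor, axis=0), factor, axis=1)
```

If the score's gradient concentrates on one filter, the heatmap lights up wherever that filter is active.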

Saliency Maps

This approach is often too coarse, so we can sharpen it by combining Vanilla Gradients with Grad-CAM to get Guided Grad-CAM.

Simple fix:

  • Apply both Vanilla Gradients and Grad-CAM. Elementwise multiply the absolute value of the Vanilla gradient with the interpolated Grad-CAM representation at the end!

Works decently well to get pixel level maps that explain why an image was classified as a dog!

Attention Visualization

Saliency maps work well for images. What can we do for text?

Broadly, the only models worth thinking about here are those that rely on self-attention

  • Why talk about hamburgers when we have access to the finest prime rib?

Let’s think about how we might be able to do something similar for text classification using BERT

Attention Visualization

Goal: Given an input sentence, \(\mathbf x\), understand how words relate to one another and ultimately lead to a particular classification!

Solution: Check out the attention weights w.r.t. the classification token

  • Problem: Multiheaded attention means that there are multiple attention heads in each layer

  • Problem: Multiple layers of attention - base BERT has 12!!!!!

Attention Visualization

\[ \begin{array}{c|ccccccc} & \mathrm{[CLS]} & I & \mathrm{loved} & \mathrm{this} & \mathrm{movie} & ! & \mathrm{[SEP]} \\ \hline \mathrm{[CLS]} & 0.05 & 0.03 & 0.30 & 0.05 & 0.40 & 0.05 & 0.12 \\ I & 0.01 & 0.10 & 0.75 & 0.05 & 0.03 & 0.03 & 0.03 \\ \mathrm{loved} & 0.02 & 0.01 & 0.05 & 0.60 & 0.25 & 0.04 & 0.03 \\ \mathrm{this} & 0.01 & 0.02 & 0.03 & 0.10 & 0.80 & 0.02 & 0.02 \\ \mathrm{movie} & 0.02 & 0.01 & 0.02 & 0.10 & 0.05 & 0.80 & 0.00 \\ ! & 0.05 & 0.05 & 0.05 & 0.05 & 0.05 & 0.05 & 0.70 \\ \mathrm{[SEP]} & 0.80 & 0.04 & 0.04 & 0.04 & 0.04 & 0.02 & 0.02 \\ \end{array} \]

Attention Visualization

Each head and layer is associated with its own attention matrix.

  • Can view them all, but this is going to be unruly.

Instead, aggregate in some meaningful way.

  • Aggregate within layer by averaging attention weight over all heads

  • Aggregate across model by averaging layer averages over all layers

Lose clarity!

Note: At each aggregation step, renormalize so that rows add up to 1!
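A minimal sketch of the within-layer aggregation, assuming row-stochastic attention matrices stacked as (heads, T, T):

```python
import numpy as np

def average_heads(attn):
    """attn: (n_heads, T, T) attention matrices for one layer, rows summing to 1."""
    A = attn.mean(axis=0)                       # average attention weights over heads
    return A / A.sum(axis=1, keepdims=True)     # renormalize so rows add up to 1
```

Averaging across layers works the same way: average the per-layer matrices, then renormalize again.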

Attention Visualization

More clever aggregation via roll-out

Transformers add a residual (identity) connection around each attention block. We can mimic this by defining an augmented layer-wise attention matrix (head-averaged within each layer):

\[ \tilde{\mathbf A}^{(\ell)} = \mathbf A^{(\ell)} + \mathbf I \]

  • Why add identity?

  • Each token actually is allowed to attend to itself via a skip path. The raw layer attention matrices don’t take this connection into account.

  • Keeps the attention model from collapsing on itself.

Attention Visualization

After re-normalization (approximating what a transformer does), we can define a rolled-out version of the attention weights that tracks attention across the entire run of the self-attention layers as:

\[ \mathbf R = \tilde{\mathbf A}^{(1)} \times \tilde{\mathbf A}^{(2)} \times ... \times \tilde{\mathbf A}^{(L)} \]

\(\mathbf R_{ij}\) tells us the total mass of attention flowing from token \(i\) in the input all the way through the network to token \(j\)

  • High values mean that there is a lot of “attention” given to \(j\) by \(i\) in the network

  • Using BERT’s class token, we can then track attention flowing into the class label all the way through!
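The roll-out computation can be sketched directly from the definition, assuming each layer's attention matrix has already been averaged over heads:

```python
import numpy as np

def attention_rollout(layer_attns):
    """layer_attns: list of (T, T) head-averaged attention matrices, layers 1..L."""
    T = layer_attns[0].shape[0]
    R = np.eye(T)
    for A in layer_attns:
        A_tilde = A + np.eye(T)                          # add the residual/skip path
        A_tilde /= A_tilde.sum(axis=1, keepdims=True)    # renormalize rows to 1
        R = R @ A_tilde                                  # accumulate the product
    return R                                             # R[i, j]: attention flow i -> j
```

Because each renormalized factor is row-stochastic, the rolled-out matrix keeps rows summing to 1, so each row of \(\mathbf R\) is still a distribution over input tokens.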

Attention Visualization

Cost: Aggregation loses specificity!

  • Small local patterns can be lost due to aggregation

  • If only our eyes could reasonably see in 120 dimensions at once…

Causal importance (in the predictive sense) is not directly encoded here

  • Too aggregated generally

  • But, can be useful for locating general patterns

Post-Hoc Importance

All of these methods are post-hoc importance metrics.

  • Train the model

  • Apply a method

  • Tease out importance from the model

What do you think is the big weakness of this approach?

Post-Hoc Importance

The clarity of explanations from regression models comes from their construction!

  • Clarity is a part of the model structure and loss functions

  • Reward clarity in model training

Post-Hoc Importance

Think about house-training a puppy

  • When the puppy doesn’t pee in the house, we want to reward it

  • Two strategies:

    • Give the dog treats at the end of the day when it doesn’t pee in the house

    • Give the dog treats immediately after it pees outside. Give light negative reinforcement when the puppy pees in the house.

  • Which would work better?

Glass-box models

Ultimately, we can’t expect a model to learn interpretable structures if we don’t tell it to do so!

Regression approaches:

  • Add L1 penalties on the weights

  • Simplify model form

Why won’t these work for modern deep learning architectures?

Glass-box models

Clever change 1: Mixtures of Experts for FFNNs

Consider a flat layer in a neural network with 4000 hidden units

  • Very expressive

We follow up each layer with a ReLU activation function.

  • How many units do we hope map to zero after applying the ReLU transformation?

Glass-box models

Remember that a weight of zero (or hidden representation) is equivalent to not including it in the model at all

  • Why L1 penalties are so nice!

Can we restructure our NN architecture to impose this sparsity constraint in a nicer way?

Glass-box models

Consider dividing the 4000 hidden units into 10 sets of 400.

At full capacity, switch on all 10 blocks at once.

If we train the model right, though, each block of 400 units can be encouraged to build expertise for a particular type of input!

  • Still allow functional variation to exist

  • But, make variation for larger scale concepts fall into one or more of the blocks!

Glass-box models

General strategy:

  1. For each feedforward layer, break the total number of hidden units into equally sized blocks of experts (K experts)
  2. For each input, pass the input through a routing function which chooses \(k < K\) experts to send it through
  3. Create hidden representations by passing the input through the \(k\) experts and proceed as normal!
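A minimal NumPy sketch of one top-\(k\) MoE layer. The linear router and the softmax gating over the chosen experts are common design choices from standard sparse-MoE setups, not details fixed by the slides:

```python
import numpy as np

def moe_layer(x, experts, router_w, k=2):
    """Top-k mixture-of-experts layer.
    x: (P,) input; experts: list of K (H, P) weight matrices; router_w: (K, P)."""
    logits = router_w @ x                           # 2. router scores, one per expert
    top = np.argsort(logits)[-k:]                   # keep only the k best experts
    gates = np.exp(logits[top] - logits[top].max())
    gates /= gates.sum()                            # softmax gate over chosen experts
    # 3. only the chosen experts do any work; the rest are skipped entirely
    h = sum(g * np.maximum(experts[i] @ x, 0.0) for g, i in zip(gates, top))
    return h, top
```

Only the routed experts contribute to the forward pass, which is exactly why the backward pass is cheaper too.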

Glass-box models

Rationale for the Mixture of Experts approach:

  1. Our hope is that each hidden unit starts to correspond to something meaningful about the inputs. Instead of thoughts and prayers, explicitly bake that into the model!
  2. Ensemble learning: remember boosted trees? A lot of simple models that are averaged can outperform one complex model. Reduces the model variance and allows for better prediction.
  3. Different experts get good at different things! Explicitly allocate 400 units just to telling dog ears from cat ears.

Glass-box Models

Significant decrease in training time for insanely large models

  • When backpropping for a specific training example, the gradient w.r.t. each layer only involves the experts that were actually routed to

  • Consider a MoE that chooses the top 2 experts for each example

    • Total operations for the full model: 4000 x N

    • Total operations for the MoE: (2 x 400) x N, plus the router network

  • Same number of overall parameters with less train time needed!

Glass-box Models

Through specialization, we get interpretability and scalability together

Important point: these two things don’t have to be separate!

  • If we design it to be interpretable, it’ll be interpretable.

  • Just need more clever design strategies.

Glass-box Models

Concept bottleneck models

  • Predict a set of human‑interpretable concepts (e.g., object attributes or high‑level features) and then use those concept predictions as inputs to a simple classifier for the final task
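A hypothetical two-stage sketch: the logistic concept predictor and linear task head below are illustrative stand-ins for whatever networks are actually used. The interpretability comes from being able to inspect the predicted concepts and the small task weights directly:

```python
import numpy as np

def concept_bottleneck(x, W_concepts, w_task):
    """Two-stage prediction: features -> interpretable concepts -> label score."""
    concepts = 1.0 / (1.0 + np.exp(-(W_concepts @ x)))  # predicted concept probabilities
    score = float(w_task @ concepts)                    # simple linear head on concepts
    return concepts, score
```

To explain a prediction, read off which concepts fired and how the (small, inspectable) task weights combine them.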

Glass-box Models

KNN Retrieval Models

  • At inference time, retrieve the top‑k most similar examples from a stored memory of training data in the encoder space and aggregate their labels (or representations) to make a prediction

  • Sorta like a deep KNN model

  • Near examples = some amount of interpretability!
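A minimal sketch, assuming we already have encoder-space embeddings for the query point and the stored training memory:

```python
import numpy as np

def knn_retrieval_predict(z_query, memory_z, memory_y, k=3):
    """Predict from the k nearest stored training embeddings in encoder space."""
    dists = np.linalg.norm(memory_z - z_query, axis=1)
    neighbors = np.argsort(dists)[:k]          # retrieve the top-k closest examples
    # the retrieved neighbors double as the explanation of the prediction
    return memory_y[neighbors].mean(), neighbors
```

The returned neighbor indices let you show the user which training examples "voted" for the prediction.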

Glass-box Models

This is the way forward to ensuring that modern AI is human-interpretable

  • Gotta teach the puppy not to pee in the house!

  • It won’t just learn it on its own.

Big money in this area right now.

  • DeepSeek is a true disruptor. Smaller cost, better predictions.